{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 初始化\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "import os\n",
    "os.chdir('E:\\GitHub\\QA-abstract-and-reasoning')\n",
    "sys.path.append('E:\\GitHub\\QA-abstract-and-reasoning')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "from seq2seq_tf2.model import Seq2Seq\n",
    "from seq2seq_tf2.train_helper import train_model, get_train_msg\n",
    "from seq2seq_tf2.batcher import batcher\n",
    "from utils.config_gpu import config_gpu\n",
    "from utils.params import get_params\n",
    "from utils.saveLoader import Vocab\n",
    "from utils.config import SEQ2SEQ_CKPT\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run utils/params.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Physical GPUs, 1 Logical GPUs\n",
      "Building the model ...\n",
      "Restored from E:\\GitHub\\QA-abstract-and-reasoning\\data\\checkpoints\\seq2seq_checkpoints\\ckpt-5\n"
     ]
    }
   ],
   "source": [
    "# GPU资源配置\n",
    "config_gpu()\n",
    "# 读取vocab训练\n",
    "vocab = Vocab(params[\"vocab_path\"], params[\"vocab_size\"])\n",
    "params['vocab_size'] = vocab.count\n",
    "# 构建模型\n",
    "print(\"Building the model ...\")\n",
    "model = Seq2Seq(params)\n",
    "# 获取保存管理者\n",
    "checkpoint = tf.train.Checkpoint(Seq2Seq=model)\n",
    "checkpoint_manager = tf.train.CheckpointManager(checkpoint, SEQ2SEQ_CKPT, max_to_keep=5)\n",
    "\n",
    "checkpoint.restore(checkpoint_manager.latest_checkpoint)\n",
    "if checkpoint_manager.latest_checkpoint:\n",
    "    print(\"Restored from {}\".format(checkpoint_manager.latest_checkpoint))\n",
    "    params[\"trained_epoch\"] = get_train_msg()\n",
    "    params[\"learning_rate\"] *= np.power(0.9, params[\"trained_epoch\"])\n",
    "else:\n",
    "    print(\"Initializing from scratch.\")\n",
    "    params[\"trained_epoch\"] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = params['epochs']\n",
    "batch_size = params['batch_size']\n",
    "pad_index = vocab.word2id[vocab.PAD_TOKEN]\n",
    "start_index = vocab.word2id[vocab.START_DECODING]\n",
    "optimizer = tf.keras.optimizers.Adam(name='Adam', learning_rate=params[\"learning_rate\"])\n",
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')\n",
    "\n",
    "# 定义损失函数\n",
    "def loss_function(real, pred):\n",
    "    mask = tf.math.logical_not(tf.math.equal(real, pad_index))\n",
    "    loss_ = loss_object(real, pred)\n",
    "    mask = tf.cast(mask, dtype=loss_.dtype)\n",
    "    loss_ *= mask\n",
    "    return tf.reduce_sum(loss_) / tf.reduce_sum(mask)\n",
    "\n",
    "def loss_function_old(real, pred):\n",
    "    mask = tf.math.logical_not(tf.math.equal(real, pad_index))\n",
    "    loss_ = loss_object(real, pred)\n",
    "    mask = tf.cast(mask, dtype=loss_.dtype)\n",
    "    loss_ *= mask\n",
    "    return tf.reduce_mean(loss_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = batcher(vocab, params)\n",
    "steps_per_epoch =params[\"steps_per_epoch\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_dec_inputs = next(iter(dataset))\n",
    "enc_input = enc_dec_inputs[\"enc_input\"]\n",
    "dec_target = enc_dec_inputs[\"target\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_output, enc_hidden = model.call_encoder(enc_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "dec_input = tf.expand_dims([start_index] * batch_size, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "dec_hidden = enc_hidden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions, _ = model(dec_input, dec_hidden, enc_output, dec_target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=12585, shape=(), dtype=float32, numpy=4.1510983>"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss_function(dec_target[:, 1:], predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=12627, shape=(), dtype=float32, numpy=2.223565>"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss_function_old(dec_target[:, 1:], predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# input\n",
    "real = dec_target[:, 1:]\n",
    "pred = predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask = tf.math.logical_not(tf.math.equal(real, pad_index))\n",
    "loss_ = loss_object(real, pred)\n",
    "mask = tf.cast(mask, dtype=loss_.dtype)\n",
    "loss_ *= mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=12529, shape=(), dtype=float32, numpy=4.1510983>"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.reduce_sum(loss_) / tf.reduce_sum(mask) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=12535, shape=(64,), dtype=float32, numpy=\n",
       "array([32.,  7., 12., 29., 23., 28., 20., 17., 17., 21., 10., 39., 14.,\n",
       "       31., 18., 18., 30., 39.,  4.,  4., 27., 14., 14., 20., 27., 23.,\n",
       "        6., 22., 18.,  9., 20., 14., 39., 17.,  6., 12., 27., 13., 11.,\n",
       "       29., 13., 39., 14., 32., 28., 15., 14., 20., 30., 20., 13., 33.,\n",
       "       16., 22., 33., 22., 20., 35., 20., 19., 12., 39., 15., 32.],\n",
       "      dtype=float32)>"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.reduce_sum(mask, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.preprocess import get_seg_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "get_seg_data\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E:\\GitHub\\QA-abstract-and-reasoning\\utils\\preprocess.py:415: FutureWarning: The signature of `Series.to_csv` was aligned to that of `DataFrame.to_csv`, and argument 'header' will change its default value from False to True: please pass an explicit value to suppress this warning.\n",
      "  _train_seg['train_seg_y'].to_csv(TRAIN_SEG_Y, index=None)\n",
      "E:\\GitHub\\QA-abstract-and-reasoning\\utils\\preprocess.py:416: FutureWarning: The signature of `Series.to_csv` was aligned to that of `DataFrame.to_csv`, and argument 'header' will change its default value from False to True: please pass an explicit value to suppress this warning.\n",
      "  _test_seg['test_seg_x'].to_csv(TEST_SEG_X, index=None)\n",
      "E:\\GitHub\\QA-abstract-and-reasoning\\utils\\preprocess.py:417: FutureWarning: The signature of `Series.to_csv` was aligned to that of `DataFrame.to_csv`, and argument 'header' will change its default value from False to True: please pass an explicit value to suppress this warning.\n",
      "  \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "create:  E:\\GitHub\\QA-abstract-and-reasoning\\data\\train_seg_x.csv\n",
      "create:  E:\\GitHub\\QA-abstract-and-reasoning\\data\\train_seg_y.csv\n",
      "create:  E:\\GitHub\\QA-abstract-and-reasoning\\data\\test_seg_x.csv\n",
      "样本数量为: 81625\n"
     ]
    }
   ],
   "source": [
    "get_seg_data()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tf2.0]",
   "language": "python",
   "name": "conda-env-tf2.0-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
